
Conversation

@weixiao-huang
Contributor

This MR makes process_weights_after_loading reusable for FP8 quantization.


@gemini-code-assist bot left a comment


Code Review

This pull request refactors weight processing for FP8 quantization to support weight updates, primarily by introducing a _wrap_parameter_or_copy helper function. This is a good change for compatibility with CUDA graphs. The change in kv_cache.py also improves robustness by ensuring quantization scales are always present. However, I've found a critical issue in Fp8MoEMethod.process_weights_after_loading where a parameter is not correctly unwrapped, leading to a no-op update and incorrect behavior in certain code paths. I've also suggested an improvement in kv_cache.py to make the code more robust by removing some overly strict assertions.
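For readers without the diff at hand, here is a minimal sketch of what a wrap-or-copy helper along these lines could look like. Only the first line of the real signature appears later in this thread; the rest of the signature and the body below are assumptions, not the PR's actual code.

import torch

def _wrap_parameter_or_copy(layer: torch.nn.Module, name: str,
                            tensor: torch.Tensor) -> None:
    # Hypothetical sketch: when a parameter with a matching shape already
    # exists, copy into its storage so pointers captured by CUDA graphs stay
    # valid across weight updates; otherwise register the processed tensor
    # as a fresh, non-trainable Parameter.
    existing = getattr(layer, name, None)
    if isinstance(existing, torch.nn.Parameter) and existing.shape == tensor.shape:
        existing.data.copy_(tensor)
    else:
        setattr(layer, name, torch.nn.Parameter(tensor, requires_grad=False))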

Comment on lines +761 to +766

critical

In the else branch of the conditional starting at line 733, the variables w2_weight and w2_weight_scale_inv are assigned torch.nn.Parameter objects on lines 755-756, instead of their underlying tensor data. Consequently, these calls to _wrap_parameter_or_copy become no-ops due to self-copying, which is likely not the intended behavior and can lead to incorrect weight updates.

This is inconsistent with how w13_weight is handled in the same block, which correctly uses .data. To fix this, you should modify lines 755-756 to extract the tensor data, like so:

# In vllm/model_executor/layers/quantization/fp8.py, lines 755-756
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data

Since the fix is outside the diff, I'm placing this comment here to highlight this critical issue.

Comment on lines +53 to +55

high

These assertions could make the code brittle. If another part of the codebase modifies these attributes partially (e.g., removes q_scale but not k_scale), these assertions will fail. The main goal here is to ensure all weights are present if any are missing. Simply checking for q_scale and then creating all weights is sufficient and more robust against unforeseen state changes.
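A minimal sketch of the suggested check-and-create pattern (hypothetical helper name; the scale attribute names come from this discussion, and the default value of 1.0 is an assumption rather than the actual kv_cache.py default):

import torch

def _ensure_kv_scales(layer: torch.nn.Module) -> None:
    # Rather than asserting that every scale is absent, check one sentinel
    # attribute and (re)create the full set whenever it is missing.
    if not hasattr(layer, "q_scale"):
        for name in ("q_scale", "k_scale", "v_scale"):
            setattr(layer, name,
                    torch.nn.Parameter(torch.tensor(1.0), requires_grad=False))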

@youkaichao marked this pull request as draft September 9, 2025 07:40
…dd missing scale attributes

Signed-off-by: huangweixiao <huangweixiao@msh.team>
@faresobeid

Update?

return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m


def _wrap_parameter_or_copy(layer: torch.nn.Module, name: str,

@kylesayrs Nov 11, 2025


As of torch 2.8 (and torch 2.7 using this PR), torch.compile supports parameter subclasses. Given this, all that should be required is updating the newly processed (possibly padded) weight; a new Parameter need not be created.


Agreed. In fact, the second branch in this code never gets triggered. All that is needed is to clean up the param = Parameter(...) statements in fp8.py that drop the weight loaders.
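For illustration, a hedged sketch of the cleanup both comments describe (w13_weight is taken from earlier in this review; the processing step is a placeholder, not the actual fp8.py code):

# Re-wrapping creates a fresh Parameter and drops attributes attached to the
# original, such as its weight_loader:
#     layer.w13_weight = torch.nn.Parameter(processed, requires_grad=False)
# Keeping the existing Parameter object (and any parameter subclass) and only
# swapping in the processed, possibly padded, tensor avoids that:
processed = process_w13(layer.w13_weight.data)  # placeholder processing step
layer.w13_weight.data = processed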
